"""
Phase 8: Probe the 69→140 attenuation of Z₁.

Diagnostics to understand why expanding from 69 to 140 countries weakens Z₁
under full EBA controls (p=0.017 → p=0.160 when adding health_exp, life_exp,
expected_growth).

Six probes:
  1. Control-by-control attenuation path (which control kills Z₁?)
  2. Subsample coefficient heterogeneity (original 69 vs added 72)
  3. DFBETA influence diagnostics (which countries move Z₁ most?)
  4. Regional stepwise addition (where does Z₁ become unstable?)
  5. Noise vs signal decomposition (β shrinkage vs SE inflation?)
  6. Bad controls test (are health_exp, life_exp, expected_growth mediators?)

Output: 6 markdown tables in followup/output/tables/
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
from scipy import stats

sys.path.insert(0, str(Path("/mnt/c/demographics_capital_flows/multilateral/followup")))
from src.model import PanelGLS
from src.macro import (
    EBA_COUNTRIES, SSA_COUNTRIES, EU_EXPANSION, EXPANSION_TIER1,
    filter_eba_sample,
)

# Project layout: processed inputs are read from DATA_DIR, markdown tables
# are written to OUTPUT_DIR (created on import if it does not exist yet).
BASE_DIR = Path("/mnt/c/demographics_capital_flows/multilateral/followup")
DATA_DIR = BASE_DIR.joinpath("data", "processed")
OUTPUT_DIR = BASE_DIR.joinpath("output", "tables")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Demographic regressors; Z_1 is always the first column of the design matrix.
DEMO_VARS = ['Z_1', 'Z_2', 'Z_3']

# Simple 4-control spec (Z₁ significant at 5%).
SIMPLE_CONTROLS = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'log_rel_opw', 'kaopen']

# Full 7-control EBA spec (Z₁ insignificant): the simple spec plus three more.
FULL_CONTROLS = [
    *SIMPLE_CONTROLS,
    'expected_growth', 'health_exp_gdp', 'life_expectancy',
]

# Original 69 = EBA-49 + SSA-20.
ORIGINAL_69 = {*EBA_COUNTRIES, *SSA_COUNTRIES}

# Added 72 = EU-10 + EXPANSION_TIER1.
ADDED_COUNTRIES = {*EU_EXPANSION, *EXPANSION_TIER1}

# Region assignments for added countries.
# Groups are applied in the order listed, so if an ISO code appeared in more
# than one group the later group would win.
REGION_MAP = {}
for region, members in (
    ('EU-10', EU_EXPANSION),
    # Asia
    ('S/SE Asia', ['BGD', 'VNM', 'KHM', 'MMR', 'LKA', 'NPL', 'LAO', 'BTN', 'MNG']),
    # MENA / Central Asia
    ('MENA/C.Asia', ['IRN', 'IRQ', 'KAZ', 'UZB', 'QAT', 'KWT', 'OMN', 'JOR',
                     'DZA', 'TUN', 'BHR', 'TKM', 'KGZ', 'TJK', 'YEM', 'LBN']),
    # LatAm
    ('LatAm', ['DOM', 'ECU', 'GTM', 'VEN', 'CRI', 'URY', 'BOL', 'PRY', 'HND', 'JAM']),
    # E. Europe / Caucasus
    ('E.Eur/Caucasus', ['UKR', 'BLR', 'GEO', 'ARM', 'AZE', 'ALB', 'MDA', 'MKD', 'BIH']),
    # SSA extension
    ('SSA-ext', ['GIN', 'MLI', 'BEN', 'TCD', 'GNQ', 'TGO', 'SLE', 'GAB', 'SWZ',
                 'LBR', 'BDI', 'CAF', 'CPV', 'LSO', 'GNB', 'SYC', 'COM']),
    # Other
    ('Other', ['SDN']),
):
    REGION_MAP.update(dict.fromkeys(members, region))

def load_panel():
    """Read ``full_panel.csv`` from DATA_DIR and keep rows with year <= 2024.

    Prints a one-line summary (row count, number of distinct ``iso3`` codes,
    first and last year) and returns the filtered frame as its own copy.
    """
    raw = pd.read_csv(DATA_DIR / "full_panel.csv")
    panel = raw.loc[raw['year'] <= 2024].copy()

    n_countries = panel['iso3'].nunique()
    first_year, last_year = panel['year'].min(), panel['year'].max()
    print(f"Loaded panel: {len(panel):,} obs, {n_countries} countries, "
          f"years {first_year}-{last_year}")
    return panel


def stars(p):
    """Return the significance marker for p-value *p*.

    ``'***'`` below 0.01, ``'**'`` below 0.05, ``'*'`` below 0.1, otherwise
    an empty string. NaN fails every comparison and so also yields ``''``.
    """
    # Thresholds are checked from strictest to loosest; first hit wins.
    for threshold, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < threshold:
            return marker
    return ''


def run_gls(df, controls, label=''):
    """Fit PanelGLS of ``ca_gdp`` on DEMO_VARS followed by *controls*.

    Rows with a missing value in the outcome, any regressor, ``iso3`` or
    ``year`` are dropped first. If fewer than 50 rows remain, a message
    tagged with *label* is printed and ``(None, None)`` is returned.

    Returns:
        ``(model, sample)`` — the fitted model (with ``feature_names`` set to
        the regressor list) and the estimation sample it was fitted on.
    """
    regressors = DEMO_VARS + controls
    required = ['ca_gdp', *regressors, 'iso3', 'year']
    sample = df.dropna(subset=required).copy()

    n_rows = len(sample)
    if n_rows < 50:
        print(f"  {label}: insufficient obs ({n_rows})")
        return None, None

    model = PanelGLS()
    model.fit(
        sample['ca_gdp'].values,
        sample[regressors].values,
        sample['iso3'].values,
        sample['year'].values,
    )
    model.feature_names = regressors
    return model, sample


def extract_z1(model):
    """Return ``(coefficient, standard error, p-value)`` for Z₁.

    Z₁ is read from position 0 of ``model.beta``, ``model.se`` and
    ``model.pvalues`` (it is always the first regressor). When *model* is
    ``None`` all three values are NaN.
    """
    if model is None:
        return (np.nan,) * 3
    z1_pos = 0
    return tuple(getattr(model, attr)[z1_pos] for attr in ('beta', 'se', 'pvalues'))


def save_table(lines, filename):
    """Write markdown table *lines* to ``OUTPUT_DIR / filename``.

    Lines are joined with newlines and a trailing newline is appended.
    The file is written as UTF-8 explicitly: the tables contain non-ASCII
    characters (Z₁, β, √, ⚑, —), which a platform-default encoding such as
    cp1252 cannot represent.
    """
    path = OUTPUT_DIR / filename
    with open(path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(lines) + '\n')
    print(f"  Saved: {path}")


# =========================================================================
# Probe 1: Control-by-Control Attenuation Path
# =========================================================================

def probe1_attenuation_path(df):
    """Track Z₁ as controls are added one at a time, on both samples.

    Starting from a Z-only regression, each step appends one control in a
    fixed order and refits on the original-69 subset and on the full frame.
    The Z₁ coefficient, SE, p-value, R², N and country count for every
    (step, sample) pair are written to ``table_probe1_attenuation_path.md``
    and returned as a DataFrame.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("PROBE 1: Control-by-Control Attenuation Path")
    print(banner)

    samples = [
        ('69-country', df[df['iso3'].isin(ORIGINAL_69)].copy()),
        ('140-country', df.copy()),
    ]

    # Order in which controls enter; step k uses the first k of them, so the
    # first label ('Z only') maps to an empty control list.
    entry_order = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'log_rel_opw', 'kaopen',
                   'expected_growth', 'health_exp_gdp', 'life_expectancy']
    step_labels = ['Z only', '+ fiscal_bal', '+ nfa_lag', '+ log_rel_opw',
                   '+ kaopen', '+ expected_growth', '+ health_exp',
                   '+ life_expectancy']
    steps = [(lbl, entry_order[:k]) for k, lbl in enumerate(step_labels)]

    records = []
    for step_label, step_controls in steps:
        for sample_name, sample_df in samples:
            model, _ = run_gls(sample_df, step_controls,
                               f"{step_label} ({sample_name})")
            coef, se, pval = extract_z1(model)
            records.append({
                'step': step_label,
                'sample': sample_name,
                'Z1_coef': coef,
                'Z1_SE': se,
                'Z1_p': pval,
                'R2': model.r_squared if model else np.nan,
                'N': model.n_obs if model else 0,
                'Countries': model.n_countries if model else 0,
            })

    results = pd.DataFrame(records)

    # Render the markdown table, one line per (step, sample) row.
    lines = [
        '# Probe 1: Control-by-Control Attenuation Path',
        '',
        'Z₁ coefficient as controls are added sequentially.',
        '',
        '| Step | Sample | Z₁ Coef | Z₁ SE | Z₁ p-val | R² | N | Countries |',
        '|------|--------|---------|-------|----------|----|---|-----------|',
    ]
    for _, row in results.iterrows():
        p_val = row['Z1_p']
        sig = '' if np.isnan(p_val) else stars(p_val)
        lines.append(
            f"| {row['step']} | {row['sample']} | "
            f"{row['Z1_coef']:.2f}{sig} | {row['Z1_SE']:.2f} | {p_val:.3f} | "
            f"{row['R2']:.3f} | {row['N']} | {row['Countries']} |"
        )
    lines += ['', '*,**,*** denote significance at 10%, 5%, 1%.']

    save_table(lines, 'table_probe1_attenuation_path.md')
    return results


# =========================================================================
# Probe 2: Subsample Coefficient Heterogeneity
# =========================================================================

def probe2_heterogeneity(df):
    """Test whether Z₁ differs between the original 69 and the added countries.

    Three parts, all written to ``table_probe2_heterogeneity.md``:

    A. Separate regressions (simple and full specs) on the original-69
       subset, the added-country subset, and the full sample.
    B. A pooled regression adding an ``is_added`` dummy and Z×``is_added``
       interactions, with per-interaction estimates and a joint test.
    C. Simple-spec regressions for each region in ``REGION_MAP``.

    The caller's frame is left untouched: the dummy and interaction columns
    are built on a private copy.

    Returns:
        Tuple ``(section_a_rows, interaction_rows, region_rows)`` of lists of
        dicts, one dict per table row.
    """
    print("\n" + "=" * 70)
    print("PROBE 2: Subsample Coefficient Heterogeneity")
    print("=" * 70)

    specs = [('Simple (4 ctrls)', SIMPLE_CONTROLS),
             ('Full (7 ctrls)', FULL_CONTROLS)]

    # Copy before adding helper columns so later probes, which receive the
    # same frame from main(), never see them.
    df = df.copy()
    df['is_added'] = (~df['iso3'].isin(ORIGINAL_69)).astype(int)

    # Interaction columns do not depend on the control spec: build them once.
    interact_vars = []
    for zv in DEMO_VARS:
        iname = f'{zv}_x_added'
        df[iname] = df[zv] * df['is_added']
        interact_vars.append(iname)

    df69 = df[df['is_added'] == 0].copy()
    df_added = df[df['is_added'] == 1].copy()

    # Part A: separate regressions on 69, added, and full 140
    section_a_rows = []
    for spec_name, controls in specs:
        for sample_name, sample_df in [('Original 69', df69),
                                        ('Added countries', df_added),
                                        ('Full 140', df)]:
            model, _ = run_gls(sample_df, controls, f"{spec_name} - {sample_name}")
            b, se, p = extract_z1(model)
            section_a_rows.append({
                'spec': spec_name,
                'sample': sample_name,
                'Z1_coef': b,
                'Z1_SE': se,
                'Z1_p': p,
                'N': model.n_obs if model else 0,
                'Countries': model.n_countries if model else 0,
            })

    # Part B: interaction test on full 140
    # Z₁×I(added), Z₂×I(added), Z₃×I(added)
    interaction_rows = []
    for spec_name, controls in specs:
        all_vars = DEMO_VARS + controls + ['is_added'] + interact_vars
        est = df.dropna(subset=['ca_gdp'] + all_vars + ['iso3', 'year']).copy()
        if len(est) < 50:
            continue

        model = PanelGLS()
        model.fit(est['ca_gdp'].values, est[all_vars].values,
                  est['iso3'].values, est['year'].values)
        model.feature_names = all_vars

        interact_indices = [all_vars.index(iv) for iv in interact_vars]
        for iv, idx in zip(interact_vars, interact_indices):
            interaction_rows.append({
                'spec': spec_name,
                'variable': iv,
                'coef': model.beta[idx],
                'SE': model.se[idx],
                'p': model.pvalues[idx],
            })

        # Joint test of H0: all three interaction coefficients are zero.
        # NOTE(review): this averages the squared t-statistics, which equals
        # the Wald F-statistic only if the interaction estimates are mutually
        # uncorrelated. They may well be correlated, so treat the p-value as
        # approximate until the full coefficient covariance is used.
        t_stats = model.beta[interact_indices] / model.se[interact_indices]
        df1 = len(interact_indices)
        df2 = model.n_obs - len(all_vars) - 1
        f_stat = np.sum(t_stats ** 2) / df1
        # sf() is the upper-tail probability; it avoids the cancellation in
        # 1 - cdf() when the p-value is very small.
        f_pval = stats.f.sf(f_stat, df1, df2)

        interaction_rows.append({
            'spec': spec_name,
            'variable': f'Joint F-test (q={df1})',
            'coef': f_stat,
            'SE': np.nan,
            'p': f_pval,
        })

    # Part C: Z₁ by region for added countries (simple spec only)
    region_rows = []
    spec_name, controls = specs[0]
    for region in sorted(set(REGION_MAP.values())):
        region_isos = {iso for iso, r in REGION_MAP.items() if r == region}
        rdf = df[df['iso3'].isin(region_isos)].copy()
        model, _ = run_gls(rdf, controls, f"{spec_name} - {region}")
        b, se, p = extract_z1(model)
        region_rows.append({
            'region': region,
            'Z1_coef': b,
            'Z1_SE': se,
            'Z1_p': p,
            'Countries': model.n_countries if model else 0,
        })

    # Build markdown
    lines = [
        '# Probe 2: Subsample Coefficient Heterogeneity',
        '',
        '## A. Z₁ by Sample',
        '',
        '| Specification | Sample | Z₁ Coef | Z₁ SE | Z₁ p-val | N | Countries |',
        '|---------------|--------|---------|-------|----------|---|-----------|',
    ]
    for r in section_a_rows:
        sig = stars(r['Z1_p']) if not np.isnan(r['Z1_p']) else ''
        lines.append(
            f"| {r['spec']} | {r['sample']} | "
            f"{r['Z1_coef']:.2f}{sig} | {r['Z1_SE']:.2f} | {r['Z1_p']:.3f} | "
            f"{r['N']} | {r['Countries']} |"
        )

    lines.extend([
        '',
        '## B. Interaction Test: Z×I(added72) on Full 140',
        '',
        '| Specification | Variable | Coef | SE | p-val |',
        '|---------------|----------|------|----|-------|',
    ])
    for r in interaction_rows:
        sig = stars(r['p']) if not np.isnan(r['p']) else ''
        # The joint-test row carries no standard error.
        se_str = f"{r['SE']:.2f}" if not np.isnan(r['SE']) else '—'
        lines.append(
            f"| {r['spec']} | {r['variable']} | "
            f"{r['coef']:.2f}{sig} | {se_str} | {r['p']:.3f} |"
        )

    lines.extend([
        '',
        '## C. Z₁ by Region (Added Countries Only, Simple 4-Control Spec)',
        '',
        '| Region | Z₁ Coef | Z₁ SE | Z₁ p-val | Countries |',
        '|--------|---------|-------|----------|-----------|',
    ])
    for r in region_rows:
        if np.isnan(r['Z1_p']):
            # Region had too few observations to fit.
            lines.append(f"| {r['region']} | — | — | — | {r['Countries']} |")
        else:
            sig = stars(r['Z1_p'])
            lines.append(
                f"| {r['region']} | {r['Z1_coef']:.2f}{sig} | {r['Z1_SE']:.2f} | "
                f"{r['Z1_p']:.3f} | {r['Countries']} |"
            )

    save_table(lines, 'table_probe2_heterogeneity.md')
    return section_a_rows, interaction_rows, region_rows


# =========================================================================
# Probe 3: DFBETA Influence Diagnostics
# =========================================================================

def probe3_dfbeta(df):
    """Leave-one-country-out DFBETA for Z₁ on the full sample.

    For each spec (simple and full) the model is fitted on all of *df* and
    then refitted with each country removed. DFBETA is the change in the Z₁
    coefficient, ``(b_full - b_excl)``, scaled by the full-sample SE; a
    country is flagged when ``|DFBETA|`` exceeds ``2 / sqrt(n_countries)``.

    The top 20 countries per spec are written to ``table_probe3_dfbeta.md``.

    Returns:
        DataFrame with one row per (spec, country); empty if no spec could
        be fitted.
    """
    print("\n" + "=" * 70)
    print("PROBE 3: DFBETA Influence Diagnostics")
    print("=" * 70)

    countries = sorted(df['iso3'].unique())
    # Set before the spec loop so the table header can always be written,
    # even when no specification has enough observations to fit.
    # NOTE(review): n counts every country present in df, including any that
    # contribute no complete rows to a given spec — confirm this is intended.
    cutoff = 2 / np.sqrt(len(countries))

    rows = []
    for spec_name, controls in [('Simple (4 ctrls)', SIMPLE_CONTROLS),
                                 ('Full (7 ctrls)', FULL_CONTROLS)]:
        # Full model
        full_model, full_est = run_gls(df, controls, f"Full ({spec_name})")
        if full_model is None:
            continue
        b_full, se_full, _ = extract_z1(full_model)

        # Rows each country contributes to this spec's estimation sample.
        obs_per_country = full_est['iso3'].value_counts()

        for iso in countries:
            n_obs_c = int(obs_per_country.get(iso, 0))
            if n_obs_c == 0:
                # The country has no rows in the estimation sample, so
                # dropping it leaves the fit unchanged: skip the refit.
                b_excl = b_full
            else:
                model_excl, _ = run_gls(df[df['iso3'] != iso], controls,
                                        f"excl {iso}")
                if model_excl is None:
                    continue
                b_excl, _, _ = extract_z1(model_excl)

            dfbeta = (b_full - b_excl) / se_full

            # Country summary stats (over all of the country's rows in df)
            cdf = df[df['iso3'] == iso]

            rows.append({
                'spec': spec_name,
                'iso3': iso,
                'group': 'Original-69' if iso in ORIGINAL_69 else 'Added-72',
                'region': REGION_MAP.get(iso, 'Original'),
                'DFBETA': dfbeta,
                'abs_DFBETA': abs(dfbeta),
                'Z1_full': b_full,
                'Z1_excl': b_excl,
                'mean_CA': cdf['ca_gdp'].mean(),
                'mean_Z1': cdf['Z_1'].mean(),
                'n_obs': n_obs_c,
                'flagged': abs(dfbeta) > cutoff,
            })

    results = pd.DataFrame(rows)

    # Build markdown (top 20 per spec)
    lines = [
        '# Probe 3: DFBETA Influence Diagnostics',
        '',
        f'Cutoff = 2/√n = {cutoff:.3f}',
        '',
    ]

    # An empty frame has no 'spec' column; emit the header only in that case.
    spec_names = results['spec'].unique() if not results.empty else []
    for spec_name in spec_names:
        spec_df = results[results['spec'] == spec_name].sort_values(
            'abs_DFBETA', ascending=False)
        top20 = spec_df.head(20)
        flagged = spec_df[spec_df['flagged']]
        n_flagged = len(flagged)
        n_flagged_orig = (flagged['group'] == 'Original-69').sum()
        n_flagged_added = (flagged['group'] == 'Added-72').sum()

        lines.extend([
            f'## {spec_name}',
            '',
            f'Flagged: {n_flagged} total ({n_flagged_orig} original, {n_flagged_added} added)',
            '',
            '| Rank | ISO3 | Group | Region | DFBETA | Z₁(excl) | Mean CA | Mean Z₁ | N obs | Flagged |',
            '|------|------|-------|--------|--------|----------|---------|---------|-------|---------|',
        ])

        for rank, (_, r) in enumerate(top20.iterrows(), 1):
            flag = '⚑' if r['flagged'] else ''
            lines.append(
                f"| {rank} | {r['iso3']} | {r['group']} | {r['region']} | "
                f"{r['DFBETA']:+.3f} | {r['Z1_excl']:.2f} | "
                f"{r['mean_CA']:.2f} | {r['mean_Z1']:.3f} | {r['n_obs']} | {flag} |"
            )
        lines.append('')

    save_table(lines, 'table_probe3_dfbeta.md')
    return results


# =========================================================================
# Probe 4: Regional Stepwise Addition
# =========================================================================

def probe4_cumulative(df):
    """Refit both specs as country groups are added one block at a time.

    The sample grows from EBA-49 through SSA-20, EU-10 and each region in
    ``REGION_MAP`` until it covers every group. At each step the simple and
    full specs are fitted and the Z₁ estimate recorded. Results go to
    ``table_probe4_cumulative.md`` and are returned as a DataFrame.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("PROBE 4: Regional Stepwise Addition")
    print(banner)

    def members_of(region_name):
        """ISO codes that REGION_MAP assigns to *region_name*."""
        return {iso for iso, r in REGION_MAP.items() if r == region_name}

    # (step label, countries newly added at that step), in build-up order.
    increments = [
        ('EBA-49', set(EBA_COUNTRIES)),
        ('+ SSA-20 (=69)', set(SSA_COUNTRIES)),
        ('+ EU-10', set(EU_EXPANSION)),
        ('+ LatAm', members_of('LatAm')),
        ('+ MENA/C.Asia', members_of('MENA/C.Asia')),
        ('+ S/SE Asia', members_of('S/SE Asia')),
        ('+ E.Eur/Caucasus', members_of('E.Eur/Caucasus')),
        ('+ SSA-ext', members_of('SSA-ext')),
        ('+ Other (=140)', members_of('Other')),
    ]

    # Turn the increments into cumulative country sets.
    steps = []
    included = set()
    for step_label, newcomers in increments:
        included = included | newcomers
        steps.append((step_label, included))

    specs = [('Simple (4 ctrls)', SIMPLE_CONTROLS),
             ('Full (7 ctrls)', FULL_CONTROLS)]

    records = []
    for step_label, country_set in steps:
        step_df = df[df['iso3'].isin(country_set)].copy()
        for spec_name, controls in specs:
            model, _ = run_gls(step_df, controls, f"{step_label} ({spec_name})")
            coef, se, pval = extract_z1(model)
            records.append({
                'step': step_label,
                'spec': spec_name,
                'Z1_coef': coef,
                'Z1_SE': se,
                'Z1_p': pval,
                'R2': model.r_squared if model else np.nan,
                'N': model.n_obs if model else 0,
                'Countries': model.n_countries if model else 0,
            })

    results = pd.DataFrame(records)

    # Markdown: one section per spec, one line per cumulative step.
    lines = [
        '# Probe 4: Regional Stepwise Addition',
        '',
        'Cumulative build-up from EBA-49 → full 140, both control specs.',
        '',
    ]
    for spec_name, _ in specs:
        lines += [
            f'## {spec_name}',
            '',
            '| Step | Z₁ Coef | Z₁ SE | Z₁ p-val | R² | N | Countries |',
            '|------|---------|-------|----------|----|---|-----------|',
        ]
        for _, row in results[results['spec'] == spec_name].iterrows():
            if np.isnan(row['Z1_p']):
                # Step could not be fitted (too few observations).
                lines.append(
                    f"| {row['step']} | — | — | — | — | {row['N']} | {row['Countries']} |")
                continue
            sig = stars(row['Z1_p'])
            lines.append(
                f"| {row['step']} | {row['Z1_coef']:.2f}{sig} | {row['Z1_SE']:.2f} | "
                f"{row['Z1_p']:.3f} | {row['R2']:.3f} | {row['N']} | {row['Countries']} |"
            )
        lines.append('')

    save_table(lines, 'table_probe4_cumulative.md')
    return results


# =========================================================================
# Probe 5: Noise vs Signal Decomposition
# =========================================================================

def _z1_xtx_inv_entry(regressors):
    """Return the Z₁ diagonal entry of (X'X)⁻¹, with X = [1, regressors].

    Z₁ must be the first column of *regressors*. The intercept column is
    always prepended, so index 1 is Z₁ regardless of whether some regressor
    happens to be constant in the sample. Returns NaN if X'X is singular.
    """
    X = np.column_stack([np.ones(regressors.shape[0]), regressors])
    try:
        return np.linalg.inv(X.T @ X)[1, 1]
    except np.linalg.LinAlgError:
        return np.nan


def _pct_of(delta, base):
    """Return *delta* as a percentage of *base*, or NaN when *base* is zero."""
    if base == 0:
        return np.nan
    return delta / base * 100


def probe5_noise_signal(df):
    """Compare β, SE, σ² and related quantities between the 69 and 140 samples.

    For each spec (simple and full) the model is fitted on the original-69
    subset and on the full frame. The Z₁ coefficient, SE and p-value, the
    residual variance, the Z₁ diagonal entry of (X'X)⁻¹, N and R² are
    tabulated with their change, and a one-line diagnosis is printed saying
    whether the coefficient or the SE moved more in percentage terms.

    Writes ``table_probe5_noise_signal.md`` and returns the rows as a
    DataFrame (empty, but with the usual columns, if no spec could be fitted).
    """
    print("\n" + "=" * 70)
    print("PROBE 5: Noise vs Signal Decomposition")
    print("=" * 70)

    df69 = df[df['iso3'].isin(ORIGINAL_69)].copy()
    df140 = df.copy()

    rows = []
    for spec_name, controls in [('Simple (4 ctrls)', SIMPLE_CONTROLS),
                                 ('Full (7 ctrls)', FULL_CONTROLS)]:
        all_vars = DEMO_VARS + controls

        m69, est69 = run_gls(df69, controls, f"69 ({spec_name})")
        m140, est140 = run_gls(df140, controls, f"140 ({spec_name})")
        if m69 is None or m140 is None:
            continue

        b69, se69, p69 = extract_z1(m69)
        b140, se140, p140 = extract_z1(m140)

        # Residual variance, with k regressors plus an intercept.
        dof_adjust = len(all_vars) + 1
        sigma2_69 = np.sum(m69.resid ** 2) / (m69.n_obs - dof_adjust)
        sigma2_140 = np.sum(m140.resid ** 2) / (m140.n_obs - dof_adjust)

        # (X'X)^{-1}[Z₁, Z₁]: the design-matrix contribution to Var(β_Z₁).
        lev69 = _z1_xtx_inv_entry(est69[all_vars].values)
        lev140 = _z1_xtx_inv_entry(est140[all_vars].values)

        # The coefficient change is taken relative to |β₆₉| so its sign
        # reflects the direction of the move; the rest use the raw base.
        beta_pct = _pct_of(b140 - b69, abs(b69))
        se_pct = _pct_of(se140 - se69, se69)
        sigma2_pct = _pct_of(sigma2_140 - sigma2_69, sigma2_69)

        metrics = [
            ('β(Z₁)', b69, b140, beta_pct),
            ('SE(Z₁)', se69, se140, se_pct),
            ('p-value(Z₁)', p69, p140, np.nan),
            ('σ² (residual var)', sigma2_69, sigma2_140, sigma2_pct),
            ('(X\'X)⁻¹[Z₁,Z₁]', lev69, lev140, _pct_of(lev140 - lev69, lev69)),
            ('N obs', m69.n_obs, m140.n_obs,
             _pct_of(m140.n_obs - m69.n_obs, m69.n_obs)),
            ('R²', m69.r_squared, m140.r_squared, np.nan),
        ]
        rows.extend(
            {
                'spec': spec_name,
                'metric': metric,
                'val_69': v69,
                'val_140': v140,
                'delta': v140 - v69,
                'pct_change': pct,
            }
            for metric, v69, v140, pct in metrics
        )

        # Diagnosis
        print(f"\n  {spec_name} diagnosis:")
        print(f"    β change: {b69:.2f} → {b140:.2f} ({beta_pct:+.1f}%)")
        print(f"    SE change: {se69:.2f} → {se140:.2f} ({se_pct:+.1f}%)")
        print(f"    σ² change: {sigma2_69:.4f} → {sigma2_140:.4f} ({sigma2_pct:+.1f}%)")

        if np.isnan(beta_pct) or np.isnan(se_pct):
            # A NaN comparison is always False; don't let it masquerade as
            # a "noise" verdict.
            print("    → Inconclusive (percentage change undefined)")
        elif abs(beta_pct) > abs(se_pct):
            print("    → Dominated by SIGNAL CHANGE (β shrinkage)")
        else:
            print("    → Dominated by NOISE INCREASE (SE inflation)")

    # Explicit columns keep the frame well-formed when no spec was fitted.
    results = pd.DataFrame(
        rows,
        columns=['spec', 'metric', 'val_69', 'val_140', 'delta', 'pct_change'],
    )

    # Markdown
    # NOTE(review): the SE² = σ² × (X'X)⁻¹ identity quoted below is the OLS
    # form; confirm how closely it matches the SEs PanelGLS reports.
    lines = [
        '# Probe 5: Noise vs Signal Decomposition',
        '',
        'Comparing 69-country vs 140-country samples.',
        'SE² = σ² × (X\'X)⁻¹, so SE change decomposes into residual variance × leverage.',
        '',
    ]
    for spec_name in ['Simple (4 ctrls)', 'Full (7 ctrls)']:
        sdf = results[results['spec'] == spec_name]
        lines.extend([
            f'## {spec_name}',
            '',
            '| Metric | 69-country | 140-country | Δ | % change |',
            '|--------|------------|-------------|---|----------|',
        ])
        for _, r in sdf.iterrows():
            pct = f"{r['pct_change']:+.1f}%" if not np.isnan(r['pct_change']) else '—'
            # Observation counts are integers; don't print them as x.0000.
            digits = 0 if r['metric'] == 'N obs' else 4
            lines.append(
                f"| {r['metric']} | {r['val_69']:.{digits}f} | "
                f"{r['val_140']:.{digits}f} | "
                f"{r['delta']:+.{digits}f} | {pct} |"
            )
        lines.append('')

    save_table(lines, 'table_probe5_noise_signal.md')
    return results


# =========================================================================
# Probe 6: Bad Controls Test
# =========================================================================

def probe6_bad_controls(df):
    """Regress each candidate mediator on the demographic variables.

    ``health_exp_gdp``, ``life_expectancy``, ``expected_growth`` and
    ``log_rel_opw`` are each used as the dependent variable in a PanelGLS
    regression on DEMO_VARS alone, for the original-69 subset and for the
    full frame. Samples with fewer than 50 complete rows yield a NaN row.

    The Z₁ results are written to ``table_probe6_bad_controls.md``; the
    returned DataFrame additionally carries the Z₂ and Z₃ coefficients and
    p-values.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("PROBE 6: Bad Controls Test")
    print(banner)
    print("If Z₁ strongly predicts a control variable, that control may be a")
    print("mediator (part of the mechanism), not a confounder. Including it")
    print("removes the demographic CHANNEL, biasing Z₁ toward zero.")

    samples = [
        ('69-country', df[df['iso3'].isin(ORIGINAL_69)].copy()),
        ('140-country', df.copy()),
    ]

    # Test: Z → each potential mediator
    candidates = ['health_exp_gdp', 'life_expectancy', 'expected_growth', 'log_rel_opw']

    records = []
    for dep_var in candidates:
        for sample_name, sample_df in samples:
            est = sample_df.dropna(subset=[dep_var, *DEMO_VARS, 'iso3', 'year']).copy()
            record = {'dep_var': dep_var, 'sample': sample_name}

            if len(est) < 50:
                # Too few observations: placeholder row, no fit attempted.
                record.update({
                    'Z1_coef': np.nan, 'Z1_SE': np.nan, 'Z1_p': np.nan,
                    'Z2_coef': np.nan, 'Z2_p': np.nan,
                    'Z3_coef': np.nan, 'Z3_p': np.nan,
                    'R2': np.nan, 'N': 0, 'Countries': 0,
                })
            else:
                model = PanelGLS()
                model.fit(est[dep_var].values, est[DEMO_VARS].values,
                          est['iso3'].values, est['year'].values)
                model.feature_names = DEMO_VARS
                record.update({
                    'Z1_coef': model.beta[0],
                    'Z1_SE': model.se[0],
                    'Z1_p': model.pvalues[0],
                    'Z2_coef': model.beta[1],
                    'Z2_p': model.pvalues[1],
                    'Z3_coef': model.beta[2],
                    'Z3_p': model.pvalues[2],
                    'R2': model.r_squared,
                    'N': model.n_obs,
                    'Countries': model.n_countries,
                })
            records.append(record)

    results = pd.DataFrame(records)

    # Markdown
    lines = [
        '# Probe 6: Bad Controls Test (Angrist & Pischke)',
        '',
        'If Z predicts a control variable, that control is a *mediator* (part of the',
        'demographic mechanism), not a *confounder*. Including mediators in the CA',
        'regression biases the demographic coefficient toward zero.',
        '',
        '## Z → Potential Mediator Regressions',
        '',
        '| Dep. Variable | Sample | Z₁ Coef | Z₁ SE | Z₁ p-val | R² | N | Countries |',
        '|---------------|--------|---------|-------|----------|----|---|-----------|',
    ]
    for _, row in results.iterrows():
        prefix = f"| {row['dep_var']} | {row['sample']} | "
        counts = f"{row['N']} | {row['Countries']} |"
        if np.isnan(row['Z1_p']):
            lines.append(prefix + "— | — | — | — | " + counts)
            continue
        sig = stars(row['Z1_p'])
        lines.append(
            prefix
            + f"{row['Z1_coef']:.3f}{sig} | {row['Z1_SE']:.3f} | {row['Z1_p']:.4f} | "
            + f"{row['R2']:.3f} | "
            + counts
        )

    # Interpretation
    lines += [
        '',
        '## Interpretation',
        '',
        '- Variables where Z₁ is significant (p < 0.05) are likely **mediators**.',
        '- Including mediators in the CA regression removes the demographic *channel*,',
        '  not confounding — this is a "bad control" problem.',
        '- If life_expectancy and health_exp are mediators, the parsimonious spec',
        '  (4 controls, Z₁ p=0.017) is the appropriate specification.',
    ]

    save_table(lines, 'table_probe6_bad_controls.md')
    return results


# =========================================================================
# Main
# =========================================================================

def main():
    """Load the panel, restrict it to the expanded sample, and run all probes.

    The six probes run in numerical order on the same filtered frame; each
    writes its own markdown table to OUTPUT_DIR.
    """
    banner = "=" * 70
    print(banner)
    print("PHASE 8: Probe the 69→140 Attenuation")
    print(banner)

    # Filter to expanded sample (140 countries)
    panel = filter_eba_sample(load_panel(), extended=True, expansion=True)

    # Run all probes, in order; their return values are not used here.
    for probe in (
        probe1_attenuation_path,
        probe2_heterogeneity,
        probe3_dfbeta,
        probe4_cumulative,
        probe5_noise_signal,
        probe6_bad_controls,
    ):
        probe(panel)

    print("\n" + banner)
    print("DONE. All 6 probe tables written to:")
    print(f"  {OUTPUT_DIR}/table_probe[1-6]_*.md")
    print(banner)


if __name__ == "__main__":
    main()
